In [1]:
import numpy as np
np.random.seed(2017)

import torch
torch.manual_seed(2017)

# Reference implementation used to check log_sum_exp below.
# scipy.misc.logsumexp was deprecated and then removed (SciPy 1.3); the function
# lives in scipy.special, so try that first and fall back only for very old SciPy.
try:
    from scipy.special import logsumexp
except ImportError:
    from scipy.misc import logsumexp

In [2]:
seq_length, num_states = 4, 2
# Random integer-valued scores, stored as float64 arrays
emissions = np.random.randint(20, size=(seq_length, num_states)).astype(float)
transitions = np.random.randint(10, size=(num_states, num_states)).astype(float)
print("Emissions:\n{}".format(emissions))
print("Transitions:\n{}".format(transitions))


Emissions:
[[  9.   6.]
 [ 13.  10.]
 [  8.  18.]
 [  3.  15.]]
Transitions:
[[ 7.  8.]
 [ 0.  8.]]

In [3]:
def viterbi_decoding(emissions, transitions):
    """Find the highest-scoring state sequence with the Viterbi algorithm (NumPy).

    Adapted from the CRF ops in TensorFlow:
    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/crf/python/ops/crf.py

    Args:
        emissions: array of shape (seq_length, num_states); emissions[i, k] is
            the score of state k at position i.
        transitions: array of shape (num_states, num_states); transitions[j, k]
            is the score of moving from state j to state k.

    Returns:
        (viterbi_score, viterbi): the best total path score, and the best state
        sequence as a list of Python ints of length seq_length.
    """
    seq_length = emissions.shape[0]
    # back_pointers[i, k] = best previous state given state k at position i (row 0 unused)
    back_pointers = np.zeros(emissions.shape, dtype=int)
    scores = emissions[0]
    # Generate most likely scores and paths for each step in sequence
    for i in range(1, seq_length):
        # score_with_transition[j, k] = best score ending in j at i-1, then j -> k
        score_with_transition = np.expand_dims(scores, 1) + transitions
        scores = emissions[i] + score_with_transition.max(axis=0)
        back_pointers[i] = score_with_transition.argmax(axis=0)
    # Follow the back pointers from the best final state to recover the path
    viterbi = [int(np.argmax(scores))]
    for bp in reversed(back_pointers[1:]):
        viterbi.append(int(bp[viterbi[-1]]))
    viterbi.reverse()
    viterbi_score = np.max(scores)
    return viterbi_score, viterbi

In [4]:
# Best score and state path from the NumPy Viterbi decoder
viterbi_decoding(emissions, transitions)


Out[4]:
(78.0, [0, 0, 1, 1])

In [5]:
def viterbi_decoding_torch(emissions, transitions):
    """Find the highest-scoring state sequence with the Viterbi algorithm (PyTorch).

    Works for tensors on any device and for tensors that require grad. Results
    are read back with int()/float() instead of .numpy(), which raises for such
    tensors. All intermediate tensors are derived from the inputs, so they share
    their device and dtype.

    Args:
        emissions: tensor of shape (seq_length, num_states); emissions[i, k] is
            the score of state k at position i.
        transitions: tensor of shape (num_states, num_states); transitions[j, k]
            is the score of moving from state j to state k.

    Returns:
        (viterbi_score, viterbi): the best total path score as a Python float,
        and the best state sequence as a list of Python ints.
    """
    scores = emissions[0]
    back_pointers = []  # back_pointers[i - 1][k] = best previous state for state k at position i
    # Generate most likely scores and paths for each step in sequence
    for i in range(1, emissions.size(0)):
        # scores_with_transitions[j, k] = best score ending in j at i-1, then j -> k
        scores_with_transitions = scores.unsqueeze(1) + transitions
        max_scores, best_prev = scores_with_transitions.max(dim=0)
        back_pointers.append(best_prev)
        scores = emissions[i] + max_scores
    # Follow the back pointers from the best final state to recover the path
    viterbi = [int(scores.argmax())]
    for bp in reversed(back_pointers):
        viterbi.append(int(bp[viterbi[-1]]))
    viterbi.reverse()
    viterbi_score = float(scores.max())
    return viterbi_score, viterbi

In [6]:
# Same decoding with the PyTorch implementation; should agree with the NumPy result
viterbi_decoding_torch(torch.Tensor(emissions), torch.Tensor(transitions))


Out[6]:
(78.0, [0, 0, 1, 1])

In [7]:
# NumPy decoder again, for side-by-side comparison with the PyTorch output above
viterbi_decoding(emissions, transitions)


Out[7]:
(78.0, [0, 0, 1, 1])

In [8]:
def log_sum_exp(vecs, axis=None, keepdims=False):
    """Numerically stable log(sum(exp(vecs))) along `axis` (NumPy).

    Mirrors scipy's logsumexp:
    https://github.com/scipy/scipy/blob/v0.18.1/scipy/misc/common.py#L20-L140

    Args:
        vecs: array-like of scores.
        axis: axis to reduce over; None reduces over all elements.
        keepdims: if True, keep the reduced axis with size 1.

    Returns:
        The log-sum-exp of `vecs` over `axis`. A slice whose entries are all
        -inf gives -inf (not nan).
    """
    vecs = np.asarray(vecs)
    # Shift by the max so exp() cannot overflow
    max_val = vecs.max(axis=axis, keepdims=True)
    # A non-finite max (e.g. every entry is -inf) would give inf - inf = nan in the
    # shift below; shift by 0 instead, as scipy does.
    max_val = np.where(np.isfinite(max_val), max_val, 0)
    with np.errstate(divide="ignore"):  # log(0) -> -inf is the intended result
        out_val = np.log(np.exp(vecs - max_val).sum(axis=axis, keepdims=keepdims))
    if not keepdims:
        max_val = max_val.squeeze(axis=axis)
    return max_val + out_val

In [9]:
def score_sequence(emissions, transitions, tags):
    """Unnormalised score of one state path (NumPy).

    Adds the emission score of every position and the transition score of
    every consecutive pair of tags; the first position has no transition term.
    Based on the CRF ops in TensorFlow:
    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/crf/python/ops/crf.py
    """
    total = emissions[0][tags[0]]
    for step, emission in enumerate(emissions[1:], start=1):
        prev_tag, cur_tag = tags[step - 1], tags[step]
        total = total + transitions[prev_tag, cur_tag] + emission[cur_tag]
    return total

In [10]:
# Score of an arbitrary state path, for comparison with the best (Viterbi) score
score_sequence(emissions, transitions, [1,1,0,0])


Out[10]:
42.0

In [11]:
correct_seq = [0, 0, 1, 1]
# Transition score of each consecutive (previous, next) pair along the path
[transitions[prev, nxt] for prev, nxt in zip(correct_seq, correct_seq[1:])]


Out[11]:
[7.0, 8.0, 8.0]

In [12]:
# Total transition score along the path (the transition part of its sequence score)
sum(transitions[prev, nxt] for prev, nxt in zip(correct_seq, correct_seq[1:]))


Out[12]:
23.0

In [13]:
# Viterbi result again, to compare against the path score computed in the next cell
viterbi_decoding(emissions, transitions)


Out[13]:
(78.0, [0, 0, 1, 1])

In [14]:
# Scoring the Viterbi path directly should reproduce the Viterbi score
score_sequence(emissions, transitions, [0, 0, 1, 1])


Out[14]:
78.0

In [15]:
def score_sequence_torch(emissions, transitions, tags):
    """Unnormalised score of one state path (PyTorch twin of score_sequence).

    Accumulates tensor entries, so the result stays a tensor and gradients
    can flow back into `emissions` and `transitions`.
    """
    total = emissions[0][tags[0]]
    for step, emission in enumerate(emissions[1:], start=1):
        prev_tag = tags[step - 1]
        cur_tag = tags[step]
        total = total + transitions[prev_tag, cur_tag] + emission[cur_tag]
    return total

In [16]:
# Same path scored with the PyTorch version; should match the NumPy score
score_sequence_torch(torch.Tensor(emissions), torch.Tensor(transitions), [0, 0, 1, 1])


Out[16]:
78.0

In [17]:
def get_all_tags(seq_length, num_labels):
    """Recursively yield every state sequence of length `seq_length`.

    Each yielded value is a new list of labels from range(num_labels). The
    first position varies fastest, e.g. (4, 2) starts [0, 0, 0, 0],
    [1, 0, 0, 0], [0, 1, 0, 0], ...
    """
    if seq_length != 0:
        for suffix in get_all_tags(seq_length - 1, num_labels):
            for first in range(num_labels):
                yield [first] + suffix
    else:
        # Base case: exactly one sequence of length zero
        yield []


list(get_all_tags(4, 2))


Out[17]:
[[0, 0, 0, 0],
 [1, 0, 0, 0],
 [0, 1, 0, 0],
 [1, 1, 0, 0],
 [0, 0, 1, 0],
 [1, 0, 1, 0],
 [0, 1, 1, 0],
 [1, 1, 1, 0],
 [0, 0, 0, 1],
 [1, 0, 0, 1],
 [0, 1, 0, 1],
 [1, 1, 0, 1],
 [0, 0, 1, 1],
 [1, 0, 1, 1],
 [0, 1, 1, 1],
 [1, 1, 1, 1]]

In [18]:
def get_all_tags_dp(seq_length, num_labels):
    """Return every state sequence of length `seq_length` over `num_labels` labels.

    Builds the sequences iteratively by prepending one label per step, with
    the label as the outer loop. The result is therefore in lexicographic
    order (first position varies slowest), e.g. [[0, 0], [0, 1], [1, 0],
    [1, 1]] for (2, 2).

    Returns:
        A list of num_labels ** seq_length lists; seq_length == 0 gives [[]].

    Raises:
        ValueError: if seq_length is negative.
    """
    if seq_length < 0:
        raise ValueError("seq_length must be non-negative, got {}".format(seq_length))
    # Start from the single empty sequence so seq_length == 0 is well defined
    tag_sequences = [[]]
    for _ in range(seq_length):
        tag_sequences = [[label] + tags for label in range(num_labels) for tags in tag_sequences]
    return tag_sequences
get_all_tags_dp(2, 2)


Out[18]:
[[0, 0], [0, 1], [1, 0], [1, 1]]

In [19]:
def brute_force_score(emissions, transitions):
    """Yield score_sequence for every possible state path, in get_all_tags_dp order.

    Exhaustive reference for checking the dynamic-programming results.
    It enumerates num_states ** seq_length paths, so DO NOT run it with large
    numbers of labels or long sequences.
    """
    all_paths = get_all_tags_dp(*emissions.shape)
    yield from (score_sequence(emissions, transitions, path) for path in all_paths)


brute_force_sequence_scores = list(brute_force_score(emissions, transitions))
print(brute_force_sequence_scores)


[54.0, 67.0, 58.0, 78.0, 45.0, 58.0, 56.0, 76.0, 44.0, 57.0, 48.0, 68.0, 42.0, 55.0, 53.0, 73.0]

In [20]:
max(brute_force_sequence_scores) # Best score calculated using brute force; should equal the Viterbi score


Out[20]:
78.0

In [21]:
log_sum_exp(np.array(brute_force_sequence_scores)) # Log partition function: log of the sum of exp(score) over all state paths


Out[21]:
78.132899613126483

In [22]:
def forward_algorithm_naive(emissions, transitions):
    """Log partition function via the forward algorithm, one state at a time (NumPy).

    Reference version of forward_algorithm with an explicit loop over states.
    Prints the forward variables before each step.

    Args:
        emissions: array of shape (seq_length, num_states) of emission scores.
        transitions: array of shape (num_states, num_states); transitions[j, k]
            is the score of moving from state j to state k.

    Returns:
        log of the sum, over all state paths, of exp(path score).
    """
    num_states = emissions.shape[1]
    scores = emissions[0]
    # Get the log sum exp score
    for i in range(1, emissions.shape[0]):
        print(scores)
        # Forward vars at timestep i. Collected with np.array rather than written into
        # np.zeros_like(scores): for integer emissions zeros_like gives an int array,
        # which would silently truncate the float log-sum-exp values.
        scores = np.array([
            log_sum_exp(scores + transitions[:, j]) + emissions[i, j]
            for j in range(num_states)
        ])
    return log_sum_exp(scores)

In [23]:
# Naive forward algorithm: prints the forward variables at each step and should
# reproduce the brute-force log partition function
forward_algorithm_naive(emissions, transitions)


[ 9.  6.]
[ 29.0000454   27.04858735]
[ 44.00017494  55.13288499]
Out[23]:
78.132899613126483

In [24]:
def forward_algorithm_vec_check(emissions, transitions):
    """Run two forward passes side by side to check log_sum_exp against scipy.

    `scores` is propagated with scipy's logsumexp (vectorised over states) and
    `scores_naive` with log_sum_exp (one call per state). Each chain is updated
    only from its own previous values, so the final pair compares two complete,
    independent forward passes. Both chains are printed before every step and
    once at the end.

    Returns:
        (scipy result, log_sum_exp result) for the log partition function.
    """
    # This is for checking the correctness of log_sum_exp function compared to scipy
    scores = emissions[0]
    scores_naive = emissions[0]
    # incoming[j, k] = transitions[k, j]: score of arriving in j from k
    incoming = transitions.T
    # Get the log sum exp score
    for i in range(1, emissions.shape[0]):
        print(scores, scores_naive)
        # Use this chain's own previous scores so the two passes stay independent
        scores = emissions[i] + logsumexp(scores + incoming, axis=1)
        scores_naive = emissions[i] + np.array([
            log_sum_exp(scores_naive + incoming[j]) for j in range(emissions.shape[1])])
    print(scores, scores_naive)
    return logsumexp(scores), log_sum_exp(scores_naive)

In [25]:
# scipy's logsumexp vs log_sum_exp inside the forward recursion; the two results should match
forward_algorithm_vec_check(emissions, transitions)


[ 9.  6.] [ 9.  6.]
[ 29.0000454   27.04858735] [ 29.0000454   27.04858735]
[ 44.00017494  55.13288499] [ 44.00017494  55.13288499]
[ 58.14879707  78.13289961] [ 58.14879707  78.13289961]
Out[25]:
(78.132899613126483, 78.132899613126483)

In [26]:
def forward_algorithm(emissions, transitions):
    """Log partition function via the vectorised forward algorithm (NumPy).

    Args:
        emissions: array of shape (seq_length, num_states) of emission scores.
        transitions: array of shape (num_states, num_states); transitions[j, k]
            is the score of moving from state j to state k.

    Returns:
        log of the sum, over all state paths, of exp(path score).
    """
    # incoming[k, j] = transitions[j, k]: score of arriving in k from j
    incoming = transitions.T
    alphas = emissions[0]
    for step in range(1, emissions.shape[0]):
        # Row k of (alphas + incoming) holds alphas[j] + transitions[j, k] for every j
        alphas = emissions[step] + log_sum_exp(alphas + incoming, axis=1)
    return log_sum_exp(alphas)

In [27]:
# Vectorised forward algorithm; should equal the brute-force log partition function
forward_algorithm(emissions, transitions)


Out[27]:
78.132899613126483

In [28]:
tt = torch.Tensor(emissions)
# Row-wise maxima (reduce over dim 1).
# NOTE(review): the recorded outputs come from an old PyTorch where max(dim) kept the
# reduced dimension (shape (4, 1)); current versions drop it (shape (4,)).
tt_max, _ = tt.max(1)

In [29]:
# Broadcast each row's max across that row. Reshaping to a column first makes this
# work whether max() kept the reduced dim (shape (4, 1)) or dropped it (shape (4,));
# a (4,) tensor cannot be expanded to (4, 2) because it aligns with the last dim.
tt_max.view(-1, 1).expand_as(tt)


Out[29]:
  9   9
 13  13
 18  18
 15  15
[torch.FloatTensor of size 4x2]

In [30]:
# Column sums (reduce over dim 0). The 1x2 size in the recorded output reflects an
# older PyTorch that kept the reduced dimension. NOTE(review): current versions give shape (2,).
tt.sum(0)


Out[30]:
 33  49
[torch.FloatTensor of size 1x2]

In [31]:
# squeeze(0) only drops dim 0 when it has size 1; here it has size 4, so tt comes back unchanged
tt.squeeze(0)


Out[31]:
  9   6
 13  10
  8  18
  3  15
[torch.FloatTensor of size 4x2]

In [32]:
# Swap the last two dimensions (negative indices count from the end), as forward_algorithm_torch does with transitions
tt.transpose(-1,-2)


Out[32]:
  9  13   8   3
  6  10  18  15
[torch.FloatTensor of size 2x4]

In [33]:
# Number of dimensions; log_sum_exp_torch uses this to turn a negative axis into a non-negative one
tt.ndimension()


Out[33]:
2

In [34]:
def log_sum_exp_torch(vecs, axis=None):
    """Numerically stable log(sum(exp(vecs))) along `axis` (PyTorch).

    Adapted from:
    http://pytorch.org/tutorials/beginner/nlp/advanced_tutorial.html#sphx-glr-beginner-nlp-advanced-tutorial-py

    Args:
        vecs: tensor of scores.
        axis: dimension to reduce (negative values count from the end);
            None reduces over all elements.

    Returns:
        Tensor with `axis` removed (0-d when axis is None). A slice whose
        entries are all -inf gives -inf (not nan).
    """
    if axis is None:
        # Reduce over every element by flattening first
        vecs = vecs.reshape(-1)
        axis = 0
    elif axis < 0:
        axis = vecs.dim() + axis
    # keepdim=True keeps max_val aligned with the reduced axis, so the shift below
    # broadcasts per slice. Without it the reduced dim is dropped, and broadcasting
    # aligns the maxima with the trailing dim. For axis=1 of a 2-D tensor that
    # subtracts the wrong maxima.
    max_val, _ = vecs.max(axis, keepdim=True)
    # A non-finite max (e.g. all -inf) would make the shift produce nan; shift by 0 instead
    max_val = torch.where(torch.isfinite(max_val), max_val, torch.zeros_like(max_val))
    out_val = torch.log(torch.exp(vecs - max_val).sum(axis))
    return max_val.squeeze(axis) + out_val

In [35]:
def forward_algorithm_torch(emissions, transitions):
    """Log partition function via the forward algorithm (PyTorch).

    Args:
        emissions: tensor of shape (seq_length, num_states) of emission scores.
        transitions: tensor of shape (num_states, num_states); transitions[j, k]
            is the score of moving from state j to state k.

    Returns:
        log of the sum, over all state paths, of exp(path score), as a tensor.
    """
    # incoming[k, j] = transitions[j, k]: score of arriving in k from j
    incoming = transitions.transpose(-1, -2)
    alphas = emissions[0]
    for step in range(1, emissions.size(0)):
        # Row k holds alphas[j] + transitions[j, k] for every previous state j
        candidates = alphas.expand_as(incoming) + incoming
        alphas = emissions[step] + log_sum_exp_torch(candidates, axis=1)
    return log_sum_exp_torch(alphas, axis=-1)

In [36]:
# PyTorch forward algorithm; should match the NumPy log partition function
forward_algorithm_torch(torch.Tensor(emissions), torch.Tensor(transitions))


Out[36]:
 78.1329
[torch.FloatTensor of size 1]

The core idea is to find the sequence of states $y = \{y_0, y_1, \dots, y_N\}$ that has the highest probability given the input $X = \{X_0, X_1, \dots, X_N\}$, where the probability of a state sequence factorizes as follows:

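$$ \begin{aligned} p(y\mid X) &\propto \prod_{i=0}^{N}{p(y_i\mid X_i)\,p(y_i \mid y_{i-1})}\\ \log{p(y\mid X)} &= \sum_{i=0}^{N}{\left[\log{p(y_i\mid X_i)} + \log{p(y_i \mid y_{i-1})}\right]} - \log{Z(X)} \end{aligned} $$

Here $Z(X)$ is the partition function: the sum of the unnormalized product over every possible state sequence. Its logarithm is what the forward algorithm above computes, and brute-force enumeration reproduces it. The $i=0$ term has no predecessor $y_{-1}$, so its transition factor is taken as $1$ (a log-score of $0$). This matches `score_sequence`, which adds no transition score at the first position.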

Now $\log{p(y_i\mid X_i)}$ and $\log{p(y_i \mid y_{i-1})}$ can be parameterized as follows:

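$$ \begin{aligned} \log{p(y_i\mid X_i)} &= \sum_{l=0}^{L-1}{\sum_{k=0}^{K-1}{w_{k}^{l}\,\phi_{k}^{l}(X_i, y_i)}}\\ \log{p(y_i\mid y_{i-1})} &= \sum_{l=0}^{L-1}{\sum_{l'=0}^{L-1}{w_{l'}^{l}\,\psi_{l'}^{l}(y_i, y_{i-1})}}\\ \implies \log{p(y\mid X)} &= \sum_{i=0}^{N}{\left(\sum_{l=0}^{L-1}{\sum_{k=0}^{K-1}{w_{k}^{l}\,\phi_{k}^{l}(X_i, y_i)}} + \sum_{l=0}^{L-1}{\sum_{l'=0}^{L-1}{w_{l'}^{l}\,\psi_{l'}^{l}(y_i, y_{i-1})}}\right)} - \log{Z(X)}\\ \implies \log{p(y\mid X)} &= \sum_{i=0}^{N}{\left(\left[\Phi(X_i)W_{emission}\right]_{y_i} + \left[W_{transition}\right]_{y_{i-1},\,y_i}\right)} - \log{Z(X)} \end{aligned} $$

In these equations:
- $Z(X)$ is the partition function, i.e. the normalizer over all state sequences.
- $[\cdot]_{y_i}$ selects the entry for label $y_i$.
- $[W_{transition}]_{y_{i-1},\,y_i}$ is the score of moving from label $y_{i-1}$ to label $y_i$.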

Where,

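  • $N + 1$ is the sequence length (positions are indexed $0, \dots, N$)
  • $K$ is the number of feature functions
  • $L$ is the number of states (labels)
  • $W_{emission}$ is a $K \times L$ matrix
  • $W_{transition}$ is an $L \times L$ matrix
  • $\Phi(X_i)$ is a $1 \times K$ feature vector
  • $\Phi(X_i)W_{emission}$ is a $1 \times L$ vector holding the emission score of every label at position $i$ (one row of `emissions` in the code above); the entry of $W_{transition}$ at row $y_{i-1}$, column $y_i$ is the score of moving from label $y_{i-1}$ to label $y_i$ (one entry of `transitions`)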

In [ ]: